{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fccf87d1d30>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from trainer import Trainer\n",
    "from model import Model\n",
    "from torch.utils.data import DataLoader\n",
    "from decomposition import Decomposition\n",
    "from turbinedataset import TurbineDataset\n",
    "from torch.nn import MSELoss\n",
    "from torch.optim import Adam\n",
    "import torch\n",
    "torch.autograd.set_detect_anomaly(True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\"\n",
    "target_col = \"speed\"\n",
    "trainrate = 0.8\n",
    "testrate = 0.1\n",
    "lr = 0.001\n",
    "epoch = 200\n",
    "seq = 24\n",
    "features = 8\n",
    "batch_size = 128\n",
    "valepoch = 1\n",
    "target_id = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_df = []\n",
    "for i in range(1, 16):\n",
    "    if i == 9:\n",
    "        continue\n",
    "    df = pd.read_csv(f\"/home/duxiangyu/data/turbine_data/data_with_turbine_features/processed_data/WuXingLing_WuXingLing{i:03d}_20231105000000_20231107000000_SCADA温数据.csv\")\n",
    "    if i == target_id:\n",
    "        target = df[target_col]\n",
    "    data_df.append(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "decomp = Decomposition(seq)\n",
    "data_base, data_diff = decomp.decomp(data_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "datalen = len(target)\n",
    "train_df_base = [base[:int(datalen * 0.8)] for base in data_base]\n",
    "test_df_base = [base[-int(datalen * 0.1):] for base in data_base]\n",
    "val_df_base = [base[int(datalen * 0.8):-int(datalen * 0.1)] for base in data_base]\n",
    "train_df_diff = [diff[:int(datalen * 0.8)] for diff in data_diff]\n",
    "test_df_diff = [diff[-int(datalen * 0.1):] for diff in data_diff]\n",
    "val_df_diff = [diff[int(datalen * 0.8):-int(datalen * 0.1)] for diff in data_diff]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_index = torch.tensor([\n",
    "    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1], \n",
    "    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0]\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_attr = torch.tensor([\n",
    "    [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [0.5], \n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = TurbineDataset(train_df_base, train_df_diff, target[:int(datalen * 0.8)], edge_index, edge_attr, seq, device)\n",
    "testset = TurbineDataset(test_df_base, test_df_diff, target[-int(datalen * 0.1):], edge_index, edge_attr, seq, device)\n",
    "valset = TurbineDataset(val_df_base, val_df_diff, target[int(datalen * 0.8):-int(datalen * 0.1)], edge_index, edge_attr, seq, device)\n",
    "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "valloader = DataLoader(valset, batch_size=batch_size)\n",
    "testloader = DataLoader(testset, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128])\n",
      "torch.Size([128, 14, 24, 8])\n",
      "torch.Size([128, 2, 15])\n",
      "torch.Size([128, 15, 24, 1])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "for b, d, ei, et, y in trainloader:\n",
    "    print(b.shape)\n",
    "    print(d.shape)\n",
    "    print(ei.shape)\n",
    "    print(et.shape)\n",
    "    print(y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(features, 64, 1, seq, target_id - 1).to(device)\n",
    "optimizer = Adam(model.parameters(), lr)\n",
    "criterion = MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model, criterion, optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training epoch 1...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1080/1080 [04:11<00:00,  4.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss is 0.07877117298819401\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 99%|█████████▊| 134/136 [00:31<00:00,  4.30it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "stack expects each tensor to be equal size, but got [14, 24, 8] at entry 0 and [14, 23, 8] at entry 105",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtestloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalepoch\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/turbine_model/gnn-frame/trainer.py:25\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, trainloader, valloader, testloader, epoch, valepoch, patience)\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain loss is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m valepoch \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 25\u001b[0m     val_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalloader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     26\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval loss is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mval_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m best_val_loss \u001b[38;5;241m>\u001b[39m val_loss:\n",
      "File \u001b[0;32m~/turbine_model/gnn-frame/trainer.py:62\u001b[0m, in \u001b[0;36mTrainer.val_epoch\u001b[0;34m(self, valloader)\u001b[0m\n\u001b[1;32m     60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mval_epoch\u001b[39m(\u001b[38;5;28mself\u001b[39m, valloader:DataLoader):\n\u001b[1;32m     61\u001b[0m     total_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n\u001b[0;32m---> 62\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m base, diff, edge_index, edge_attr, y \u001b[38;5;129;01min\u001b[39;00m tqdm(valloader):\n\u001b[1;32m     63\u001b[0m         total_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_step(base, diff, edge_index, edge_attr, y)\n\u001b[1;32m     64\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m total_loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mlen\u001b[39m(valloader)\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/tqdm/std.py:1178\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1175\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m   1177\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1178\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m   1179\u001b[0m         \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m   1180\u001b[0m         \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m   1181\u001b[0m         \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    629\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    630\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    635\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    674\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    676\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m    677\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:54\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n\u001b[0;32m---> 54\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollate_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:277\u001b[0m, in \u001b[0;36mdefault_collate\u001b[0;34m(batch)\u001b[0m\n\u001b[1;32m    216\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdefault_collate\u001b[39m(batch):\n\u001b[1;32m    217\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    218\u001b[0m \u001b[38;5;124;03m    Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.\u001b[39;00m\n\u001b[1;32m    219\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    275\u001b[0m \u001b[38;5;124;03m        >>> default_collate(batch)  # Handle `CustomType` automatically\u001b[39;00m\n\u001b[1;32m    276\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 277\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcollate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcollate_fn_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdefault_collate_fn_map\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:144\u001b[0m, in \u001b[0;36mcollate\u001b[0;34m(batch, collate_fn_map)\u001b[0m\n\u001b[1;32m    141\u001b[0m transposed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39mbatch))  \u001b[38;5;66;03m# It may be accessed twice, so we use a list.\u001b[39;00m\n\u001b[1;32m    143\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(elem, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[0;32m--> 144\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m [collate(samples, collate_fn_map\u001b[38;5;241m=\u001b[39mcollate_fn_map) \u001b[38;5;28;01mfor\u001b[39;00m samples \u001b[38;5;129;01min\u001b[39;00m transposed]  \u001b[38;5;66;03m# Backwards compatibility.\u001b[39;00m\n\u001b[1;32m    145\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    146\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:144\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    141\u001b[0m transposed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39mbatch))  \u001b[38;5;66;03m# It may be accessed twice, so we use a list.\u001b[39;00m\n\u001b[1;32m    143\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(elem, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[0;32m--> 144\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[43mcollate\u001b[49m\u001b[43m(\u001b[49m\u001b[43msamples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcollate_fn_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcollate_fn_map\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m samples \u001b[38;5;129;01min\u001b[39;00m transposed]  \u001b[38;5;66;03m# Backwards compatibility.\u001b[39;00m\n\u001b[1;32m    145\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    146\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:121\u001b[0m, in \u001b[0;36mcollate\u001b[0;34m(batch, collate_fn_map)\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m collate_fn_map \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    120\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m elem_type \u001b[38;5;129;01min\u001b[39;00m collate_fn_map:\n\u001b[0;32m--> 121\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcollate_fn_map\u001b[49m\u001b[43m[\u001b[49m\u001b[43melem_type\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcollate_fn_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcollate_fn_map\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    123\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m collate_type \u001b[38;5;129;01min\u001b[39;00m collate_fn_map:\n\u001b[1;32m    124\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(elem, collate_type):\n",
      "File \u001b[0;32m~/miniconda3/envs/gru/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:174\u001b[0m, in \u001b[0;36mcollate_tensor_fn\u001b[0;34m(batch, collate_fn_map)\u001b[0m\n\u001b[1;32m    172\u001b[0m     storage \u001b[38;5;241m=\u001b[39m elem\u001b[38;5;241m.\u001b[39m_typed_storage()\u001b[38;5;241m.\u001b[39m_new_shared(numel, device\u001b[38;5;241m=\u001b[39melem\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m    173\u001b[0m     out \u001b[38;5;241m=\u001b[39m elem\u001b[38;5;241m.\u001b[39mnew(storage)\u001b[38;5;241m.\u001b[39mresize_(\u001b[38;5;28mlen\u001b[39m(batch), \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlist\u001b[39m(elem\u001b[38;5;241m.\u001b[39msize()))\n\u001b[0;32m--> 174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: stack expects each tensor to be equal size, but got [14, 24, 8] at entry 0 and [14, 23, 8] at entry 105"
     ]
    }
   ],
   "source": [
    "trainer.train(trainloader, valloader, testloader, epoch, valepoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_t, y_p = trainer.predict(testloader)\n",
    "print(len(y_p))\n",
    "print(len(y_t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 假设 predictions 和 actual_values 已经定义\n",
    "# predictions 是模型的预测值，actual_values 是真实值\n",
    "\n",
    "actual_values = y_t\n",
    "predictions = y_p\n",
    "print(len(actual_values))\n",
    "length = 1000\n",
    "actual_values = actual_values[:length]\n",
    "predictions = predictions[:length]\n",
    "\n",
    "# 创建散点图\n",
    "plt.scatter(actual_values, predictions, color='blue', label='Actual vs Predicted')\n",
    "\n",
    "# 添加标签和标题\n",
    "plt.xlabel('Actual Values')\n",
    "plt.ylabel('Predicted Values')\n",
    "plt.title('Actual vs Predicted Values')\n",
    "\n",
    "# 添加对角线，用于比较完美预测的情况\n",
    "plt.plot([min(actual_values), max(actual_values)], [min(actual_values), max(actual_values)], linestyle='--', color='red', linewidth=2, label='Perfect Prediction')\n",
    "\n",
    "# 显示图例\n",
    "plt.legend()\n",
    "\n",
    "# 显示图形\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mseloss = 0\n",
    "maeloss = 0\n",
    "for t, p in zip(y_t, y_p):\n",
    "    mseloss += (t - p) ** 2\n",
    "    maeloss += abs(t - p)\n",
    "print(mseloss / len(y_t))\n",
    "print(maeloss / len(y_t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "length = 50\n",
    "x = [i for i in range(length)]\n",
    "plt.plot(x, predictions[:length], color=\"red\")\n",
    "plt.plot(x, actual_values[:length])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gru",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
